# This script produces Figures 2, 3, and 4 of the main paper, as well as all the appendix figures for the 40-topic (k = 40) model with covariates.

# Make sure the output directory for the k = 40 figures exists.
# `recursive = TRUE` also creates "figs" when it is missing, and
# `showWarnings = FALSE` keeps re-runs quiet when the directories are
# already there (an unguarded dir.create() warns in that case).
dir.create(file.path("figs", "k40"), recursive = TRUE, showWarnings = FALSE)

library(tidyverse)
library(arabicStemR)
library(tidytext)
library(topicmodels)
library(lubridate)
library(ggplot2)
library(quanteda)
library(stm)
library(stats)
library(ggthemes)
library(ggpubr)

# Clear the workspace so that only the objects restored from k40.RData are
# present below (attached packages are not affected by rm()).
rm(list = ls())

# Reading data ------------------------------------------------------------
# The code below uses `topic_model` and `meta`, so k40.RData has to
# provide both.
load("k40.RData")

# Tidy the model's gamma matrix into a long data frame.
gamma <- tidy(topic_model, matrix = "gamma")

# Expose the article identifier under the name `document` so that it can
# serve as the join key.
# NOTE(review): assumes `document` in the tidied output and `article` in
# `meta` hold the same kind of identifier -- confirm.
meta <- select(mutate(meta, document = article), -article)

gamma <- gamma %>% left_join(meta, by = "document")
gc()

# Naming topics -----------------------------------------------------------
# Hand-assigned labels for the 40 topics: the i-th label belongs to topic
# number i. The labels later become column names, plot titles and PDF file
# names, so their exact spelling matters.
topics <- tibble(
  topic = 1:40,
  `Topic Label` = c(
    "communication", "assad", "isis", "temperature", "elections",
    "weather", "foreign fighters", "united states", "announcements",
    "conspiracies and plots",
    "election winners", "legislation", "terrorist attacks", "speeches",
    "victims",
    "gulf", "accident", "media", "europe", "yemen",
    "culture and sports", "capitulation", "economy", "diplomacy",
    "religion",
    "transport", "natural resources", "national unity", "lebanon",
    "finance",
    "bureaucracy", "education", "fighting terrorism",
    "Israel and Palestine", "ba'th party",
    "Iraq", "Turkey", "russia", "Iran", "regional politics"
  )
)

# Replace the numeric topic id with its label. The join key is spelled out
# so the match cannot silently change if another shared column name ever
# appears, and so the "Joining with `by = ...`" message is not printed.
gamma <- gamma %>%
  ungroup() %>%
  left_join(topics, by = "topic") %>%
  select(-topic) %>%
  rename(topic = `Topic Label`)

# Daily prevalence: mean gamma per topic label and date. Rows whose label
# is not longer than two characters are discarded first.
gamma <- gamma %>%
  filter(nchar(topic) > 2) %>%
  select(document, topic, date, gamma) %>%
  group_by(topic, date) %>%
  summarise(avg = mean(gamma))

# Wide layout: one row per date, one column per topic label.
gamma2 <- spread(gamma, key = topic, value = avg)

# Weekly smoothing: every numeric (topic) column is replaced by its mean
# within the week the date falls in, so each date carries the weekly
# average. The grouping is dropped afterwards so that it does not leak into
# later subsetting and joins.
gamma3 <- gamma2 %>%
  mutate(week = floor_date(date, "week")) %>%
  select(date, week, everything()) %>%
  group_by(week) %>%
  mutate_if(.predicate = is.numeric, .funs = mean) %>%
  ungroup()

# Mark the weekly-average columns with a "_fit" suffix. Columns are picked
# by name rather than by position, so the result does not depend on `date`
# and `week` being the first two columns.
fit_cols <- !(names(gamma3) %in% c("date", "week"))
names(gamma3)[fit_cols] <- paste0(names(gamma3)[fit_cols], "_fit")

# Put the raw daily averages next to their weekly counterparts. After the
# renaming `date` is the only shared column; naming it explicitly documents
# the key and avoids the implicit-join message.
gamma3 <- left_join(gamma3, gamma2, by = "date")

#' Plot the daily and weekly-average prevalence of one topic and save a PDF
#'
#' Draws the daily average gamma of topic `y1` (dark grey) together with the
#' series stored in column `<y1>_fit` (black), adds a dashed vertical line
#' at 15 March 2011, and writes the figure to `<out_dir>/<y1>.pdf`.
#'
#' @param y1 Name of the topic column to plot (a single string). `data`
#'   must also contain the column `paste0(y1, "_fit")`.
#' @param data Data frame holding a `date` column and the two topic
#'   columns. Defaults to the script-level `gamma3`.
#' @param out_dir Directory the PDF is written to. Defaults to "figs/k40".
#' @return The ggplot object; the PDF is written as a side effect.
plot_fun <- function(y1, data = gamma3, out_dir = "figs/k40") {
  stopifnot(is.character(y1), length(y1) == 1L)
  y2 <- paste0(y1, "_fit")

  absent <- setdiff(c("date", y1, y2), names(data))
  if (length(absent) > 0) {
    stop(
      "plot_fun(): column(s) not found in `data`: ",
      paste(absent, collapse = ", "),
      call. = FALSE
    )
  }

  # Copy the three series into a plain data frame with fixed column names.
  # This lets aes() refer to ordinary columns instead of `df$...` vectors
  # and is unaffected by any dplyr grouping still attached to `data`.
  df <- data.frame(
    date = data[["date"]],
    raw = data[[y1]],
    fit = data[[y2]]
  )

  p <- ggplot(df, aes(x = date)) +
    geom_line(aes(y = raw, colour = "raw")) +
    geom_line(aes(y = fit, colour = "fit")) +
    geom_vline(xintercept = dmy("15-03-2011"), linetype = 2) +
    # Named values pin each series to its colour instead of relying on the
    # alphabetical order of the colour labels.
    scale_colour_manual(values = c(raw = "darkgrey", fit = "black")) +
    labs(
      title = tools::toTitleCase(y1),
      y = expression(paste("Average  ", gamma)),
      x = "Year"
    ) +
    theme_few() +
    theme(
      legend.position = "none",
      title = element_text(size = 20),
      axis.text = element_text(size = 14)
    ) +
    coord_cartesian(ylim = c(0, 0.25)) +
    scale_x_date(date_breaks = "2 year", date_labels = "%y") +
    scale_y_continuous(breaks = c(0, 0.10, 0.20))

  # Pass the plot by name: the first formal of ggsave() is `filename`.
  ggsave(
    filename = file.path(out_dir, paste0(y1, ".pdf")),
    plot = p,
    width = 6,
    height = 5
  )

  p
}



# Topic columns are everything in gamma3 except `date`, `week` and the
# "<topic>_fit" weekly-average columns. The pattern is anchored so that a
# topic label which merely contains "fit", "date" or "week" is not dropped.
plot_names <- names(gamma3)[!str_detect(names(gamma3), "_fit$|^date$|^week$")]

# One PDF per topic. seq_along() makes the loop a no-op when there are no
# topic columns, whereas 1:length() would iterate over c(1, 0).
for (i in seq_along(plot_names)) {
  plot_fun(y1 = plot_names[i])
}
